// Copyright (c) 2022 wick.zt arg-go is licensed under Mulan PSL v2.
// You can use this software according to the terms and conditions of
// the Mulan PSL v2. You may obtain a copy of Mulan PSL v2 at:
// http://license.coscl.org.cn/MulanPSL2
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES
// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
// See the Mulan PSL v2 for more details.

package arg

import (
	"fmt"
	"io"
	"os"
	"reflect"
	"sort"
	"strings"
)

// ArgParser parses command-line arguments into registered Args.
// Create one with New; the package-level CommandLine parser is built the same way.
type ArgParser struct {
	// flags maps each registered flag spelling ("--long" or "-s") to the
	// name of the arg it belongs to.
	flags map[string]string

	// args maps an arg name to the arg itself.
	args map[string]*Arg

	// list holds every registered arg in registration order. It is sorted
	// by Order, then by name, before parsing and before printing usage.
	list []*Arg

	// Name identifies the parser in the usage header ("Usage of <Name>:").
	Name string

	// Output is where Usage writes; New sets it to os.Stderr.
	Output io.Writer
}

// New returns an ArgParser identified by name, with help registered as its
// usage flag at OrderHighestPriority. A nil help selects DefaultUsageFlag().
// An empty name is rejected with an error. Usage output goes to os.Stderr
// until Output is changed.
func New(name string, help *Flag) (*ArgParser, error) {
	if len(name) == 0 {
		return nil, fmt.Errorf("a name is required")
	}

	ap := &ArgParser{Name: name, Output: os.Stderr}
	ap.flags = make(map[string]string)
	ap.args = make(map[string]*Arg)
	ap.list = make([]*Arg, 0)

	usage := help
	if usage == nil {
		usage = DefaultUsageFlag()
	}

	if err := ap.AddFlag(usage, OrderHighestPriority, ap.defaultHelpTrigger); err != nil {
		return nil, err
	}
	return ap, nil
}

var (
	// CommandLine is the package-level parser used by the top-level helpers
	// (Prepare, Add, AddFlag, Parse, Usage). It is named after os.Args[0].
	// If the program was started with an empty or missing argv[0], a fixed
	// placeholder name is used instead. New would reject an empty name and
	// leave CommandLine nil.
	//
	// The error from New is discarded deliberately. With a non-empty name,
	// the only remaining failure is registering the default usage flag on a
	// fresh parser. NOTE(review): CommandLine would be nil in that case;
	// confirm DefaultUsageFlag().ToArg cannot fail.
	CommandLine, _ = New(func() string {
		if len(os.Args) > 0 && os.Args[0] != "" {
			return os.Args[0]
		}
		return "program"
	}(), nil)
)

// defaultHelpTrigger is the TriggerFunc that New registers for the help flag.
// When the flag's value is true, it prints the usage message and terminates
// the process with exit status 0.
//
// New accepts a caller-supplied help flag. Its value is therefore checked
// rather than blindly asserted, so a non-bool help flag yields an error
// instead of a panic.
func (ap *ArgParser) defaultHelpTrigger(v interface{}) error {
	show, ok := v.(bool)
	if !ok {
		return fmt.Errorf("help flag value must be a bool, got %T", v)
	}
	if show {
		ap.Usage()
		os.Exit(0)
	}
	return nil
}

// prepare registers an Arg for every field of the struct value v (not a
// pointer to it) that carries a TagKey tag. It recurses into every
// struct-typed field, whether or not that field is tagged.
//
// parentName and parentTag describe the struct-typed field that v was
// reached through; both are empty for the top-level call. Inside a nested
// struct:
//   - each arg name is prefixed with "<parentName>.";
//   - if the parent field carries a tag, that tag must be exactly one long
//     flag. Every long flag of the inner fields is rewritten to
//     "<parentLong>-<inner>", and inner short flags are rejected.
//
// NOTE(review): only the immediate parent contributes a prefix. Structs
// nested two or more levels deep do not accumulate outer names or flags.
func (ap *ArgParser) prepare(parentName, parentTag string, v reflect.Value) error {
	vt := v.Type()
	if vt.Kind() != reflect.Struct {
		return &TypeError{NotStruct, vt}
	}

	// The parent's long-flag prefix is identical for every field. It is
	// parsed and validated once, lazily, at the first tagged field. This
	// keeps the original behavior of only reporting a bad parent tag when
	// a tagged field actually needs it.
	var (
		parentLong    token
		parentChecked bool
	)

	for i := 0; i < vt.NumField(); i++ {
		field := vt.Field(i)
		fv := v.Field(i)

		if field.Type.Kind() == reflect.Struct {
			ptag, _ := field.Tag.Lookup(TagKey)
			if err := ap.prepare(field.Name, ptag, fv); err != nil {
				return err
			}
			continue
		}

		tag, ok := field.Tag.Lookup(TagKey)
		if !ok {
			continue
		}

		tokens, err := parseTag(tag)
		if err != nil {
			return err
		}

		name := field.Name
		if parentName != "" {
			name = parentName + "." + field.Name
		}

		if parentName != "" && parentTag != "" {
			if !parentChecked {
				parentTokens, err := parseTag(parentTag)
				if err != nil {
					return err
				}
				if len(parentTokens) != 1 || !parentTokens[0].isLongFlag() {
					return &ValueError{
						InvalidFlag,
						"only long flag is supported for nested struct",
						parentTag,
					}
				}
				parentLong = parentTokens[0]
				parentChecked = true
			}

			for j, t := range tokens {
				// With a parent long-flag prefix present, inner short flags
				// are disallowed.
				if t.isShortFlag() {
					return &ValueError{
						InvalidFlag,
						"short flag is disallowed in the named-nested struct",
						t,
					}
				}
				if t.isLongFlag() {
					// t[2:] drops the leading "--" of the inner flag.
					tokens[j] = token(fmt.Sprintf("%s-%s", parentLong, t[2:]))
				}
			}
		}

		a, err := define(name, tokens, fv)
		if err != nil {
			return err
		}
		if err := ap.Add(a); err != nil {
			return err
		}
	}

	return nil
}

// Prepare registers an Arg for every tagged field of the struct that v
// points to, descending into struct-typed fields.
//
// v must be a non-nil pointer to a struct. A nil interface or a non-pointer
// yields a *TypeError of kind NotPointer. A nil pointer yields kind IsNil.
func (ap *ArgParser) Prepare(v interface{}) error {
	rv := reflect.ValueOf(v)
	if !rv.IsValid() {
		// v is an untyped nil. There is no reflect.Type to report, and
		// rv.Type() would panic on the zero Value.
		// NOTE(review): confirm TypeError.Error tolerates a nil type.
		return &TypeError{NotPointer, nil}
	}
	if rv.Kind() != reflect.Ptr {
		return &TypeError{NotPointer, rv.Type()}
	}
	if rv.IsNil() {
		return &TypeError{IsNil, rv.Type()}
	}

	return ap.prepare("", "", rv.Elem())
}

// Prepare registers the tagged fields of the struct pointed to by v with the
// package-level CommandLine parser; see (*ArgParser).Prepare.
func Prepare(v interface{}) (err error) {
	parser := CommandLine
	err = parser.Prepare(v)
	return err
}

// Add registers a under its long flag ("--long"), its short flag ("-s") and
// its name, then appends it to the parser's arg list.
//
// If any of these is already taken, it returns a *ValueError of kind
// RedefinedArg. Checks run in this order: long flag, short flag, name.
// All checks run before any state is modified, so a failed Add leaves the
// parser unchanged. In particular, no dangling flag entries point at an
// arg that was never registered.
func (ap *ArgParser) Add(a *Arg) error {
	var longKey, shortKey string

	if a.long != "" {
		longKey = "--" + a.long
		if _, ok := ap.flags[longKey]; ok {
			return &ValueError{RedefinedArg, "redefined long flag", longKey}
		}
	}
	if a.short != 0 {
		shortKey = string([]byte{'-', a.short})
		if _, ok := ap.flags[shortKey]; ok {
			return &ValueError{RedefinedArg, "redefined short flag", shortKey}
		}
	}
	if _, ok := ap.args[a.name]; ok {
		return &ValueError{RedefinedArg, "redefined arg", a.name}
	}

	// Every check passed; only now mutate the parser.
	if longKey != "" {
		ap.flags[longKey] = a.name
	}
	if shortKey != "" {
		ap.flags[shortKey] = a.name
	}
	ap.args[a.name] = a
	ap.list = append(ap.list, a)
	return nil
}

// Add registers a with the package-level CommandLine parser; see
// (*ArgParser).Add.
func Add(a *Arg) error {
	parser := CommandLine
	return parser.Add(a)
}

// AddFlag converts flag into an Arg with the given order and trigger (via
// Flag.ToArg) and registers the result with Add. The first error from
// either step is returned.
func (ap *ArgParser) AddFlag(flag *Flag, order int, trigger TriggerFunc) error {
	converted, err := flag.ToArg(order, trigger)
	if err == nil {
		err = ap.Add(converted)
	}
	return err
}

// AddFlag registers flag with the package-level CommandLine parser; see
// (*ArgParser).AddFlag.
func AddFlag(flag *Flag, order int, trigger TriggerFunc) error {
	parser := CommandLine
	return parser.AddFlag(flag, order, trigger)
}

// parse resolves a single option token s to its registered Arg and applies
// any value carried inline.
//
// s has the form "-s", "--long" or "<flag>=<value>":
//   - for an arg with nargs == 0, an inline value is an error. Otherwise the
//     arg's current bool value is flipped;
//   - for any other arg, an inline value is passed to setValue. Without one,
//     the arg is returned untouched so the caller can supply the value from
//     the next token.
//
// Tokens that do not start with '-' (including the empty string) are
// rejected; positional args are not supported yet.
func (ap *ArgParser) parse(s string) (*Arg, error) {
	if !strings.HasPrefix(s, "-") {
		// TODO: positional args
		return nil, &ValueError{InvalidFlag, "options must start with '-'", s}
	}

	ss := strings.SplitN(s, "=", 2)

	key := ss[0]
	argName, ok := ap.flags[key]
	if !ok {
		return nil, &ValueError{UnknownFlag, "unknown flag", key}
	}
	a, ok := ap.args[argName]
	if !ok {
		return nil, &ValueError{UnknownArg, "unknown arg", argName}
	}

	if a.nargs == 0 {
		if len(ss) == 2 {
			return nil, &ValueError{NoFollowingArg, "no following arg", key}
		}
		// A zero-nargs arg is a switch. Guard the assertion so that a
		// misdeclared switch surfaces as an error instead of a panic.
		cur, ok := a.Value().(bool)
		if !ok {
			return nil, &ValueError{InvalidFlag, "zero-arg flag does not hold a bool value", key}
		}
		if err := a.setValue(!cur); err != nil {
			return nil, err
		}
		return a, nil
	}

	if len(ss) == 2 {
		if err := a.setValue(ss[1]); err != nil {
			return nil, err
		}
	}

	return a, nil
}

// sortArgs orders ap.list by ascending Order. Ties are broken by arg name,
// so the resulting order is deterministic.
func (ap *ArgParser) sortArgs() {
	list := ap.list
	sort.Slice(list, func(x, y int) bool {
		left, right := list[x], list[y]
		if left.Order != right.Order {
			return left.Order < right.Order
		}
		return left.name < right.name
	})
}

// Parse processes the option tokens in args and then runs every arg's
// trigger in sorted order.
//
// Each token must be at least two characters long. If an arg is not final
// after its token has been handled, the following token is consumed as its
// value; a missing following token is reported as MissingFollowingArg.
// The first error encountered stops processing and is returned.
func (ap *ArgParser) Parse(args []string) (err error) {
	ap.sortArgs()

	for pos := 0; pos < len(args); pos++ {
		raw := args[pos]
		if len(raw) < 2 {
			return &ValueError{InvalidValueRepr, "arg repr string too short", raw}
		}

		arg, perr := ap.parse(raw)
		if perr != nil {
			return perr
		}
		if arg.final {
			continue
		}

		// The arg still needs a value: take it from the next token.
		next := pos + 1
		if next >= len(args) {
			return &ValueError{MissingFollowingArg, "missing value for arg", arg.name}
		}
		if err = arg.setValue(args[next]); err != nil {
			return err
		}
		pos = next
	}

	for _, arg := range ap.list {
		if err = arg.trigger(); err != nil {
			return err
		}
	}
	return nil
}

// Parse parses the process arguments, excluding the program name, with the
// package-level CommandLine parser.
func Parse() error {
	args := os.Args[1:]
	return CommandLine.Parse(args)
}

// Usage writes a help message to ap.Output. The message has two parts:
//
//   - a header ("Usage:" or "Usage of <Name>:"), then a synopsis line: one
//     entry per arg after os.Args[0], optional args wrapped in brackets, and
//     value-taking args followed by their name;
//   - a detailed listing split into a "Flags:" section (Order < OrderDefault)
//     and an "Options:" section (Order >= OrderDefault). Each entry shows the
//     arg's spelling and usage text, followed by either "(required)" or its
//     default value.
//
// The args are sorted with sortArgs before printing.
func (ap *ArgParser) Usage() {
	ap.sortArgs()
	out := ap.Output

	if ap.Name == "" {
		fmt.Fprint(out, "Usage:")
	} else {
		fmt.Fprintf(out, "Usage of %s:", ap.Name)
	}
	fmt.Fprintf(out, "\n    %s", os.Args[0])

	for _, a := range ap.list {
		var example string
		if a.short != 0 {
			example = string([]byte{'-', a.short})
		} else if a.long != "" {
			example = "--" + a.long
		}
		if a.nargs != 0 {
			example += " " + a.name
		}
		if !a.Required() {
			example = "[" + example + "]"
		}
		fmt.Fprintf(out, " %s", example)
	}
	fmt.Fprint(out, "\n\n")

	sb := &strings.Builder{}
	section := ""
	for _, a := range ap.list {
		// ap.list is sorted by ascending Order, so all flags precede all
		// options and each header is emitted at most once. Tracking the
		// current section explicitly ensures the "Options:" header is also
		// printed when there are no flags at all.
		header := "Options:"
		if a.Order < OrderDefault {
			header = "Flags:"
		}
		if header != section {
			if section != "" {
				sb.WriteByte('\n')
			}
			fmt.Fprintln(sb, header)
			section = header
		}

		switch {
		case a.short != 0 && a.long != "":
			fmt.Fprintf(sb, "    -%c,--%s", a.short, a.long)
		case a.short != 0:
			fmt.Fprintf(sb, "    -%c", a.short)
		case a.long != "":
			fmt.Fprintf(sb, "    --%s", a.long)
		}
		switch a.nargs {
		case 0:
			// A switch takes no value placeholder.
		case 1:
			sb.WriteByte(' ')
			sb.WriteString(strings.ToUpper(a.name))
		case 2:
			// not supported yet
		}
		sb.WriteString("\n\t")

		sb.WriteString(a.usage)
		if a.Required() {
			fmt.Fprint(sb, " (required)")
		} else if a.defv != nil {
			fmt.Fprintf(sb, " (default: %v)", a.defv)
		}

		fmt.Fprintln(out, sb.String())
		sb.Reset()
	}
	fmt.Fprintln(out)
}

// Usage prints the help message of the package-level CommandLine parser;
// see (*ArgParser).Usage.
func Usage() {
	parser := CommandLine
	parser.Usage()
}
